import logging, wandb

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader

import torch.cuda as cuda
import torch.distributed as distributed
import torch.multiprocessing as multiprocessing
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.utils.data.distributed import DistributedSampler as Sampler

from ..dataset.dataset import *
from ..models.classifier import *
from .utils import *
from ..utils import *

# Module-level logger; its name is whatever the project helper
# ``module_structure`` computes from this file's path.
logger = logging.getLogger(name=module_structure(__file__))

def _make_dataloaders(cfg, trainset, testset):
    """Build the train/test DataLoaders (and DistributedSamplers under DDP).

    Returns:
        ``(trainloader, testloader, trainsampler, testsampler)``.
        ``testloader`` and ``testsampler`` are ``None`` when ``testset`` is empty;
        both samplers are ``None`` when ``cfg.ddp`` is false.
    """
    batch_size = cfg.Classifier.batch_size
    trainsampler = testsampler = testloader = None

    if cfg.ddp:
        world_size, rank = distributed.get_world_size(), distributed.get_rank()
        trainsampler = Sampler(
            trainset, world_size, rank, shuffle=True, seed=cfg.seed, drop_last=True)
        trainloader = DataLoader(
            trainset, sampler=trainsampler, batch_size=batch_size, num_workers=4)
        if len(testset) > 0:
            testsampler = Sampler(
                testset, world_size, rank, shuffle=False, seed=cfg.seed, drop_last=True)
            testloader = DataLoader(
                testset, sampler=testsampler, batch_size=batch_size, num_workers=4)
    else:
        trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size, num_workers=4)
        if len(testset) > 0:
            testloader = DataLoader(testset, shuffle=True, batch_size=batch_size, num_workers=4)

    return trainloader, testloader, trainsampler, testsampler


def _make_classifier(cfg, trainset, *args, **kwargs):
    """Instantiate the network named by ``cfg.Classifier.Name``.

    Extra ``*args``/``**kwargs`` are forwarded to the network constructor.

    Raises:
        NotImplementedError: if the name is not LinearNet, ConvNet or ResNet.
    """
    name = cfg.Classifier.Name
    if name == "LinearNet":
        return LinearNet(cfg, *args, num_classes=trainset.num_classes,
                         input_shape=trainset.input_shape, **kwargs)
    if name == "ConvNet":
        return ConvNet(cfg, *args, num_channels=trainset.num_channels,
                       num_classes=trainset.num_classes, **kwargs)
    if name == "ResNet":
        return ResNet(cfg, *args, num_channels=trainset.num_channels,
                      num_classes=trainset.num_classes, **kwargs)
    raise NotImplementedError(f'cfg.Classifier.Name: {cfg.Classifier.Name}')


def _save_classifier_state(net, path):
    """Save ``net``'s state dict to ``path`` as CPU tensors, without moving ``net``.

    A DDP wrapper is unwrapped first, so every checkpoint uses the plain module's
    key names (no ``module.`` prefix). Under DDP only global rank 0 writes, so the
    ranks do not race on the same file.

    Returns:
        True if this process wrote the file, False otherwise.
    """
    if isinstance(net, DDP):
        if distributed.get_rank() != 0:
            return False
        net = net.module
    state = net.state_dict()
    # Replace values in the returned (fresh) dict with CPU copies; this keeps the
    # dict's ``_metadata`` and never touches the live, on-device parameters.
    for key in list(state):
        state[key] = state[key].cpu()
    torch.save(state, path)
    return True


def train(cfg, *args, **kwargs):
    """Train the classifier configured by ``cfg`` with SGD and save its weights.

    Reads ``cfg.device``, ``cfg.ddp``, ``cfg.seed``, ``cfg.local_rank``,
    ``cfg.wandb``, ``cfg.Data.*`` and ``cfg.Classifier.*``. Extra ``*args`` /
    ``**kwargs`` are forwarded to the network constructor. Under DDP the default
    process group is assumed to be initialized by the caller (TODO confirm).

    Every ``evaluation_interval`` global iterations, the function:

    - logs the average loss and the train/test accuracy (optionally to wandb).

    Every ``save_interval`` global iterations, it:

    - writes a checkpoint to ``cfg.Classifier.path + ".{i}"``, where ``i`` is the
      0-based global iteration index.

    The final weights go to ``cfg.Classifier.path``.

    Returns:
        ``{"acc": train_acc}``: the last training-set accuracy measured. If no
        evaluation happened during training, it is measured once at the end.
    """
    device = cfg.device

    trainset, testset = load_dataset(cfg.Data.path)
    logger.info('Dataset is built')

    if cfg.Data.balanced:
        trainset = _balanced_dataset(trainset, cfg.Data.amount_per_class, trainset.num_classes)
        testset = _balanced_dataset(testset, cfg.Data.amount_per_class, testset.num_classes)

    trainloader, testloader, trainsampler, testsampler = _make_dataloaders(cfg, trainset, testset)
    logger.info("Dataloaders are built")

    net = _make_classifier(cfg, trainset, *args, **kwargs)
    logger.info("Network is built")

    net = net.to(device)
    if cfg.ddp:
        net = DDP(net, device_ids=[cfg.local_rank])
        logger.info("Network is on DDP mode")
        optimizer = ZeRO(net.parameters(), optim.SGD,
                         lr=cfg.Classifier.lr, weight_decay=cfg.Classifier.weight_decay)
    else:
        optimizer = optim.SGD(net.parameters(),
                              lr=cfg.Classifier.lr, weight_decay=cfg.Classifier.weight_decay)

    # Explicit ``torch.nn`` rather than relying on ``nn`` leaking in via a star import.
    criterion = torch.nn.CrossEntropyLoss()

    num_epochs = cfg.Classifier.epoches
    eval_interval = cfg.Classifier.evaluation_interval
    save_interval = cfg.Classifier.save_interval
    iters_per_epoch = len(trainloader)
    total_iters = iters_per_epoch * num_epochs

    # The logging window is keyed on the *global* iteration index, so the running
    # loss must carry across epoch boundaries; resetting it per epoch would make
    # the first window of each epoch sum fewer than ``eval_interval`` batches.
    running_loss = 0.0
    train_acc = None  # stays None if no evaluation happens during training

    logger.info("Training loop starts")
    for epoch in range(num_epochs):
        if cfg.ddp:
            distributed.barrier()
            trainsampler.set_epoch(epoch)
            if testsampler is not None:  # absent when the test set is empty
                testsampler.set_epoch(epoch)

        # ``i`` counts iterations globally, starting at epoch * iters_per_epoch.
        for i, (inputs, labels) in enumerate(trainloader, epoch * iters_per_epoch):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if (i + 1) % eval_interval == 0:
                averaged_training_loss = running_loss / eval_interval
                logger.info(f'[epoch: {epoch + 1:6d}/{num_epochs}, iter: {i + 1:5d}/{total_iters:5d}] loss: {averaged_training_loss:.6f}')
                running_loss = 0.0

                train_acc = accuracy(net, trainloader)
                logger.info(f'\t Acc on training set: {train_acc:5.3f}%')

                if testloader is not None:
                    test_acc = accuracy(net, testloader)
                    logger.info(f'\t Acc on testing set: {test_acc:5.3f}%')
                else:
                    test_acc = 0.0

                if cfg.wandb:
                    wandb.log({
                        "loss": averaged_training_loss,
                        "train_acc": train_acc,
                        "test_acc": test_acc,
                    }, step=i + 1)

            if (i + 1) % save_interval == 0:
                checkpoint_path = cfg.Classifier.path + f".{i}"
                if _save_classifier_state(net, checkpoint_path):
                    logger.info(f'model is saved to {checkpoint_path}')

    logger.info('Finished Training')
    if train_acc is None:
        # No evaluation ran (e.g. evaluation_interval > total iterations);
        # measure once so the returned accuracy is always defined.
        train_acc = accuracy(net, trainloader)

    if _save_classifier_state(net, cfg.Classifier.path):
        logger.info(f'model is saved to {cfg.Classifier.path}')

    return {"acc": train_acc}
